// This code is a modified version of the C++ source underlying
// quanteda.textmodels::textmodel_wordfish(), rewritten by the author
// to allow user-specified initial values.

// [[Rcpp::depends(RcppArmadillo)]]

#include <math.h>
#include <RcppArmadillo.h>
using namespace Rcpp;
using namespace arma;

// [[Rcpp::export]]
Rcpp::List qatd_cpp_wordfish_dense_boot(mat wfm, colvec dir, colvec priors, colvec tol, bool abs_err, 
                                        colvec init_alpha, colvec init_psi, colvec init_beta, colvec init_theta){

    // DEFINE INPUTS

    mat Y = wfm;
    colvec priorvec = priors;
    colvec tolvec = tol;
    colvec dirvec = dir;

    double priorprecalpha = priorvec(0);
    double priorprecpsi = priorvec(1);
    double priorprecbeta = priorvec(2);
    double priorprectheta = priorvec(3);

    int N = Y.n_rows;
    int K = Y.n_cols;

    // SET INITIAL VALUES
    colvec alpha = init_alpha;
    colvec psi = init_psi;
    colvec beta = init_beta;
    colvec theta = init_theta;

    // Create temporary variables
    mat pars = mat(2,1);
    mat newpars = mat(2,1);
    mat G = mat(2,1);
    mat H = mat(2,2);
    double loglambdaik;
    colvec lambdai = colvec(K);
    colvec lambdak = colvec(N);
    double stepsize = 1.0;
    double cc = 0.0;
    int inneriter = 0;
    int outeriter = 0;

    double lastlp = -2000000000000.0;
    double lp = -1.0 * (sum(0.5 * ((alpha % alpha) * (priorprecalpha))) +
                        sum(0.5 * ((psi % psi) * (priorprecpsi))) +
                        sum(0.5 * ((beta % beta)*(priorprecbeta))) +
                        sum(0.5 * ((theta % theta)*(priorprectheta))));

    for (int i = 0; i < N; i++){
        for (int k=0; k < K; k++){
            loglambdaik = alpha(i) + psi(k) + beta(k) * theta(i);
            lp = lp + loglambdaik * Y(i,k) - exp(loglambdaik);
        }
    }

    // BEGIN WHILE LOOP
    double err = (abs_err == true) ? fabs(lp - lastlp) : (lp - lastlp);
    while ((err > tolvec(0)) && outeriter < 100) {
        outeriter++;

        // UPDATE WORD PARAMETERS
        for (int k = 0; k < K; k++) {
            cc = 1;
            inneriter = 0;
            if (outeriter == 1) stepsize = 0.5;
            while ((cc > tolvec(1)) && inneriter < 10){
                inneriter++;
                lambdak = exp(alpha + psi(k) + beta(k) * theta);
                G(0,0) = sum(Y.col(k) - lambdak) - psi(k) * (priorprecpsi);
                G(1,0) = sum(theta % (Y.col(k) - lambdak)) - beta(k) * (priorprecbeta);
                H(0,0) = -sum(lambdak) - priorprecpsi;
                H(1,0) = -sum(theta % lambdak);
                H(0,1) = H(1,0);
                H(1,1) = -sum((theta % theta) % lambdak) - priorprecbeta;
                pars(0,0) = psi(k);
                pars(1,0) = beta(k);
                newpars(0,0) = pars(0,0) - stepsize * (H(1,1) * G(0,0) - H(0,1) * G(1,0)) / (H(0,0) * H(1,1) - H(0,1) * H(1,0));
                newpars(1,0) = pars(1,0) - stepsize * (H(0,0) * G(1,0) - H(1,0) * G(0,0)) / (H(0,0) * H(1,1) - H(0,1) * H(1,0));
                psi(k) = newpars(0,0);
                beta(k) = newpars(1,0);
                cc = as_scalar(max(abs(newpars - pars)));
                stepsize = 1.0;
            }
        }

        // UPDATE DOCUMENT PARAMETERS
        for (int i = 0; i < N; i++){
            cc = 1;
            inneriter = 0;
            if (outeriter == 1) stepsize = 0.5;
            while ((cc > tolvec(1)) && inneriter < 10){
                inneriter++;
                lambdai = exp(alpha(i) + psi + beta * theta(i));
                G(0,0) = sum(trans(Y.row(i)) - lambdai) - alpha(i) * priorprecalpha;
                G(1,0) = sum(beta % (trans(Y.row(i)) - lambdai)) - theta(i) * priorprectheta;
                H(0,0) = -sum(lambdai) - priorprecalpha;
                H(1,0) = -sum(beta % lambdai);
                H(0,1) = H(1,0);
                H(1,1) = -sum((beta % beta) % lambdai) - priorprectheta;
                pars(0,0) = alpha(i);
                pars(1,0) = theta(i);
                newpars(0,0) = pars(0,0) - stepsize * (H(1,1) * G(0,0) - H(0,1) * G(1,0)) / (H(0,0) * H(1,1) - H(0,1) * H(1,0));
                newpars(1,0) = pars(1,0) - stepsize * (H(0,0) * G(1,0) - H(1,0) * G(0,0)) / (H(0,0) * H(1,1) - H(0,1) * H(1,0));
                alpha(i) = newpars(0,0);
                theta(i) = newpars(1,0);
                cc = as_scalar(max(abs(newpars - pars)));
                stepsize = 1.0;
            }
        }

        alpha = alpha - mean(alpha);
        theta = (theta - mean(theta))/stddev(theta);

        // CHECK LOG-POSTERIOR FOR CONVERGENCE
        lastlp = lp;
        lp = -1.0 * (sum(0.5 * ((alpha % alpha) * (priorprecalpha))) + sum(0.5 * ((psi % psi) * (priorprecpsi))) + sum(0.5 * ((beta % beta)*(priorprecbeta))) + sum(0.5 * ((theta % theta)*(priorprectheta))));
        for (int i = 0; i < N; i++){
            for (int k = 0; k < K; k++){
                loglambdaik = alpha(i) + psi(k) + beta(k) * theta(i);
                lp = lp + loglambdaik * Y(i,k) - exp(loglambdaik);
            }
        }
        // Rprintf("%d: %f2\\n",outeriter,lp);
        //Rcout<<"outeriter="<<outeriter<<"  lp - lastlp= "<<lp - lastlp<<std::endl;
        err = (abs_err == true) ? fabs(lp - lastlp) : (lp - lastlp);
        // END WHILE LOOP
    }

    // Fix Global Polarity

    // added the -1 because C counts from ZERO...  -- KB
    if (theta(dirvec(0) - 1) > theta(dirvec(1) - 1)) {
        beta = -beta;
        theta = -theta;
    }

    // DEFINE OUTPUT

    return Rcpp::List::create(Rcpp::Named("theta.boot") = wrap(theta));

}
